import calibration as cal
from sklearn.linear_model import LogisticRegression
import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
import random

from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import IsolationForest
from sklearn.svm import OneClassSVM

from src.data import (
    load_wilds_dataset, 
    Dataset, 
    extract_dataset, 
    get_wilds_dataloader, 
    dataloader_from_dataset, 
    get_dataloader, 
    load_dataset_from_store,
)
from src.models import load_model

from dataclasses import dataclass


def seed_all(seed: int = 36):
    """Seed the global RNGs of Python's ``random``, PyTorch and NumPy.

    Args:
        seed: value passed to every seeding function (default 36).
    """
    for seeder in (random.seed, torch.manual_seed, np.random.seed):
        seeder(seed)
    
    
# (dataset, model, split) triples for which load_data() loads pre-extracted
# data (load_dataset_from_store for standard datasets, load_wilds_dataset for
# WILDS) instead of running the classifier. For WILDS datasets a split that is
# missing from this list makes load_data() raise ValueError; for the other
# datasets a missing test split is extracted and written to disk instead.
save_data = [
    ("rxrx", "densenet", "val"),
    ("rxrx", "densenet", "id_test"),
    ("rxrx", "densenet", "test"),
    ("camelyon17", "densenet", "val"),
    ("camelyon17", "densenet", "id_val"),
    ("camelyon17", "densenet", "test"),
    ("cifar-100-c", "clip", "test"),
    ("cifar-100", "clip", "clean_val"),
    ("cifar-100", "clip", "mix_val"),
    ("cifar-100", "clip", "id_test"),
    ("imagenet", "clip", "test"),
    ("imagenet", "resnet", "test"),
    ("imagenet-v2", "resnet", "clean_val"),
    ("imagenet-v2", "resnet", "mix_val"),
    ("imagenet-v2", "resnet", "id_test"),
    ("imagenet-v2", "clip", "clean_val"),
    ("imagenet-v2", "clip", "mix_val"),
    ("imagenet-v2", "clip", "id_test"),
    ("imagenet-sketch", "resnet", "test"),
]


def clean_method_name(mtd):
    """Turn an internal method identifier into a human-readable label.

    Everything from the first ``"_quad"`` onwards is discarded, then a fixed
    sequence of substring substitutions is applied. The substitutions run in
    order, so e.g. ``"smmce"`` is rewritten before the plain ``"mce"`` rule.

    Args:
        mtd: method identifier such as ``"bce_fixed"`` or ``"full-cal"``.

    Returns:
        The display name, e.g. ``"TLBCE fixed"`` or ``"Full - Calibrated"``.
    """
    label = mtd.partition("_quad")[0]

    substitutions = (
        ("bce", "TLBCE"),
        ("smmce", "MMCE"),
        ("mce", "MCE"),
        ("_fixed", " fixed"),
        ("_optim", " optimized"),
        ("conf", "Confidence"),
        ("osvm", "One-class SVM"),
        ("isoforest", "Iso. Forest"),
        ("full-cal", "Full - Calibrated"),
        ("full-uncal", "Full - Uncalibrated"),
    )
    for old, new in substitutions:
        label = label.replace(old, new)
    return label


def load_data(args):
    """Build the validation and test datasets (and loaders) for an experiment.

    Side effect: sets ``args.is_wilds``.

    Non-WILDS datasets:
      * The classifier from ``load_model(args)`` is run over a loader for
        ``args.dataset`` (built with ``mix=True`` and
        ``n_samples=args.n_val_samples``) to extract the validation dataset.
      * If ``(args.test_split, args.model, "test")`` is listed in ``save_data``
        the test dataset is loaded with ``load_dataset_from_store`` and the
        returned test loader is ``None``. Otherwise the test split is extracted
        with the classifier and its features / logits / labels are saved as
        ``../data/processed/<test_split>/<model>/test_{features,logits,labels}.pt``.

    WILDS datasets (see ``is_wilds``):
      * ``args.val_split`` and ``args.test_split`` must both be listed in
        ``save_data`` for ``(args.dataset, args.model)``; they are loaded with
        ``load_wilds_dataset`` and wrapped with ``get_wilds_dataloader``. The
        validation set is subsampled to ``args.n_val_samples`` before its
        loader is built.

    In both cases the test dataset is finally subsampled to
    ``args.n_test_samples``.

    Returns:
        ``(val_dataset, val_data_loader, test_dataset, test_data_loader)``.

    Raises:
        ValueError: if a requested WILDS split is not listed in ``save_data``.
    """
    import os

    args.is_wilds = is_wilds(args.dataset)

    if not args.is_wilds:
        cls_model = load_model(args)

        val_data_loader = get_dataloader(
            args.dataset,
            args.batch_size,
            pre_process=cls_model.preprocess,
            mix=True,
            n_samples=args.n_val_samples,
        )
        print("Extracting val logits")
        val_dataset = extract_dataset(cls_model, val_data_loader)

        split = "test"
        if (args.test_split, args.model, split) in save_data:
            test_dataset = load_dataset_from_store(args.test_split, args.model, split)
            test_data_loader = None
        else:
            test_data_loader = get_dataloader(
                args.test_split,
                args.batch_size,
                pre_process=cls_model.preprocess,
                mix=False,
            )
            test_dataset = extract_dataset(cls_model, test_data_loader, loader_type="hf")

            # Cache the extracted tensors. Create the target directory first so
            # torch.save does not fail when it does not exist yet.
            save_root = "../data/processed/{}/{}/{}_".format(args.test_split, args.model, split)
            os.makedirs(os.path.dirname(save_root), exist_ok=True)
            for field in ("features", "logits", "labels"):
                torch.save(
                    torch.Tensor(getattr(test_dataset, field)),
                    "{}{}.pt".format(save_root, field),
                )
    else:
        val_key = (args.dataset, args.model, args.val_split)
        if val_key not in save_data:
            raise ValueError(
                "No pre-extracted WILDS data for (dataset, model, split)={}".format(val_key)
            )
        val_dataset = load_wilds_dataset(args, args.val_split)
        val_dataset.random_sample(n_samples=args.n_val_samples)
        val_data_loader = get_wilds_dataloader(val_dataset, args)

        test_key = (args.dataset, args.model, args.test_split)
        if test_key not in save_data:
            raise ValueError(
                "No pre-extracted WILDS data for (dataset, model, split)={}".format(test_key)
            )
        test_dataset = load_wilds_dataset(args, args.test_split)
        test_data_loader = get_wilds_dataloader(test_dataset, args)

    # NOTE(review): test loaders are built *before* this subsample (unlike the
    # WILDS val loader). Whether they see the subsampled data depends on how
    # get_dataloader / get_wilds_dataloader hold the dataset - confirm.
    test_dataset.random_sample(n_samples=args.n_test_samples)

    return val_dataset, val_data_loader, test_dataset, test_data_loader

    
    
def is_wilds(dataset):
    """Return True when ``dataset`` is one of the WILDS benchmarks handled here."""
    wilds_datasets = ("rxrx", "camelyon17")
    return dataset in wilds_datasets


def get_model_name(args, net_name):
    """Build a run identifier from the experiment configuration.

    The base name encodes validation split, validation sample count, network
    name, coverage weight, seed, scaling mode and epoch count. ``_beta_<beta>``
    is appended for the quadratic coverage loss, and ``_noising`` for WILDS
    runs with noising enabled.

    Args:
        args: configuration namespace (reads ``val_split``, ``n_val_samples``,
            ``coverage_weight``, ``seed``, ``scaling``, ``n_epoch``,
            ``cov_loss``, ``beta``, ``is_wilds``, ``noising``).
        net_name: name of the selector network.

    Returns:
        The identifier string.
    """
    parts = [
        f"split_{args.val_split}_n_{args.n_val_samples}_net_{net_name}"
        f"_covwht_{args.coverage_weight}_seed_{args.seed}"
        f"_scaling_{args.scaling}_ep_{args.n_epoch}"
    ]
    if args.cov_loss == "quad":
        parts.append(f"_beta_{args.beta}")
    if args.is_wilds and args.noising:
        parts.append("_noising")
    return "".join(parts)
    
    
def coverage_loss(selector, g, weights, coverage_weight, n, args, device):
    """Coverage regulariser added to the selective loss.

    Args:
        selector: object whose ``cov_loss`` attribute selects the variant
            (``"bce"`` or ``"quad"``).
        g: selector logits; only used by the ``"bce"`` variant.
        weights: selection weights; only used by the ``"quad"`` variant.
        coverage_weight: multiplier applied to the penalty.
        n: normaliser for the ``"bce"`` variant.
        args: must provide ``beta`` (target coverage) for ``"quad"``.
        device: device on which the target-coverage tensor is created.

    Returns:
        ``"bce"``: ``coverage_weight / n * sum_i BCE_with_logits(g_i, 1)``.
        ``"quad"``: ``coverage_weight * (beta - mean(weights)) ** 2`` as a
        shape-(1,) tensor.

    Raises:
        ValueError: if ``selector.cov_loss`` is not a known variant.
    """
    if selector.cov_loss == "bce":
        cov_loss = F.binary_cross_entropy_with_logits(g, torch.ones_like(g), reduction='sum')
        return cov_loss * (coverage_weight / n)

    if selector.cov_loss == "quad":
        # The target coverage is a constant, so it needs no gradient. The squared
        # gap is already non-negative, so no clamp at zero is required.
        # NOTE(review): this penalises over- as well as under-coverage; the
        # SelectiveNet penalty is max(0, beta - coverage) ** 2 - confirm intent.
        target = torch.tensor([args.beta], dtype=torch.float32, device=device)
        return coverage_weight * (target - torch.mean(weights)) ** 2

    raise ValueError("Unknown coverage loss: {!r}".format(selector.cov_loss))


def selective_loss(selector, probs, labels, output, weights, args):
    """Weighted selective loss of the classifier on the selected examples.

    ``targets`` is the 0/1 indicator that ``argmax(probs)`` equals ``labels``.

    Variants (``selector.sel_loss``):
      * ``"bce"``: BCE between ``output`` and ``targets``, weighted by
        ``weights`` and averaged. With ``args.scaling == "temp"`` the output is
        first reduced to its maximum over the last dimension.
      * ``"mce"``: per-example cross-entropy against ``labels`` (computed on
        ``probs`` when ``args.scaling == "platt"``, otherwise on ``output``),
        weighted by ``weights`` and averaged.
      * ``"smmce"``: S-MMCE (p = 2) of ``output`` against ``targets`` weighted
        by ``weights``, scaled by ``1 / sqrt(n)``. With ``args.scaling ==
        "temp"`` the output is first reduced to its max over the last dim.

    Args:
        selector: object with a ``sel_loss`` attribute.
        probs: classifier probabilities; argmax over the last dimension.
        labels: ground-truth class indices.
        output: calibrator output (shape depends on ``args.scaling``).
        weights: per-example selection weights.
        args: configuration providing ``scaling``.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if ``selector.sel_loss`` is not a known variant.
    """
    preds = torch.argmax(probs, -1)
    # Build the 0/1 correctness target straight from the comparison. Wrapping a
    # bool tensor in the legacy ``torch.Tensor(...)`` constructor is rejected by
    # older PyTorch releases.
    targets = (preds == labels).float()
    n = probs.shape[0]

    if selector.sel_loss == "bce":
        if args.scaling == "temp":
            output = torch.max(output, -1)[0]
        ce_loss = nn.BCELoss(reduction='none')(output, targets)
        return (ce_loss * weights).mean()

    if selector.sel_loss == "mce":
        loss_fn = nn.CrossEntropyLoss(reduction='none')
        # NOTE(review): CrossEntropyLoss expects unnormalised logits; in the
        # platt branch it receives probabilities - confirm this is intended.
        ce_inputs = probs if args.scaling == "platt" else output
        ce_loss = loss_fn(ce_inputs, labels)
        return (ce_loss * weights).mean()

    if selector.sel_loss == "smmce":
        if args.scaling == "temp":
            output = torch.max(output, -1)[0]
        sel_loss = compute_smmce_loss(
            outputs=output,
            targets=targets,
            weights=weights,
            pnorm=2,
        )
        sel_loss *= (1 / np.sqrt(n))
        return sel_loss

    raise ValueError("Unknown selective loss: {!r}".format(selector.sel_loss))


def _calibrated_top_confidence(h, test_dataset, args):
    """Return the calibrated top-label confidence of every test example.

    With ``args.scaling == "platt"``, ``h.calibrate(test_dataset.probs)`` is
    used. Otherwise ``h`` is applied to the test logits (moved to CUDA) and
    the maximum softmax probability of its output is taken.
    """
    if args.scaling == "platt":
        return h.calibrate(test_dataset.probs)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits_cal = h(torch.Tensor(test_dataset.logits).cuda())
        probs_cal = F.softmax(logits_cal, -1).cpu().numpy()
    return np.max(probs_cal, -1)


def _selective_curves(ranking_scores, test_conf, test_dataset, low_quantile):
    """Score the top-ranked fraction of the test set at several coverage levels.

    Examples are sorted by ``ranking_scores`` (highest first). For each
    coverage level i = low_quantile, low_quantile + 5, ..., 100 (percent), the
    first ``r = int(i / 100 * test_dataset.n)`` examples are scored with
    ECE-1, ECE-2 and accuracy, using ``test_conf`` as the confidence and the
    argmax of ``test_dataset.probs`` as the prediction. Levels where r is 0
    are skipped, because the metrics are undefined on an empty set.

    Returns:
        ``(X, y)``: X lists the realised coverages ``r / n``; y maps
        ``"ece_1"``, ``"ece_2"`` and ``"acc"`` to lists aligned with X.
    """
    order = np.argsort(-ranking_scores)
    conf_sorted = test_conf[order]
    labels_sorted = test_dataset.labels[order]
    y_hat = np.argmax(test_dataset.probs[order], -1)
    correct = np.array((y_hat == labels_sorted), dtype=int)

    X = []
    y = {"ece_1": [], "ece_2": [], "acc": []}
    for i in range(low_quantile, 105, 5):
        r = int((i / 100) * test_dataset.n)
        if r == 0:
            continue
        head_conf = conf_sorted[:r]
        head_correct = correct[:r]
        X.append(r / test_dataset.n)
        y["ece_1"].append(calc_score("ECE-1", head_conf, head_correct))
        y["ece_2"].append(calc_score("ECE-2", head_conf, head_correct))
        y["acc"].append(np.sum(head_correct) / len(head_correct))
    return X, y


def _fit_outlier_models(train_input_features):
    """Fit a One-class SVM and an Isolation Forest on ``train_input_features``.

    Returns:
        List of ``(name, fitted_model, sign)`` where ``sign`` multiplies
        ``score_samples`` (1 for both, i.e. higher score = ranked first).
    """
    # Neither estimator is given a random_state; seeding NumPy's global RNG
    # makes the fits reproducible (this also resets the caller's global RNG).
    np.random.seed(1)
    outlier_models = []

    osvm = OneClassSVM(max_iter=10000)
    osvm.fit(train_input_features)
    outlier_models.append(("osvm", osvm, 1))

    isoforest = IsolationForest()
    isoforest.fit(train_input_features)
    outlier_models.append(("isoforest", isoforest, 1))

    return outlier_models


def produce_confidence_baseline(h, test_dataset, baseline_results, args):
    """Selective-prediction curves when test examples are ranked by calibrated confidence.

    Stores ``(X, y)`` (see ``_selective_curves``) in ``baseline_results["conf"]``.
    """
    test_conf = _calibrated_top_confidence(h, test_dataset, args)
    baseline_results["conf"] = _selective_curves(test_conf, test_conf, test_dataset, args.low_quantile)


def produce_ood_baselines(val_dataset, h, test_dataset, baseline_results, args):
    """Selective-prediction curves when test examples are ranked by outlier scores.

    Outlier detectors are fitted on ``val_dataset.features``. Test examples are
    ranked by each detector's ``score_samples`` (highest first), while the
    ECEs use the calibrated top-label confidence from ``h``. Results are stored
    in ``baseline_results`` under ``"osvm"`` and ``"isoforest"``.
    """
    outlier_models = _fit_outlier_models(val_dataset.features)
    test_conf = _calibrated_top_confidence(h, test_dataset, args)

    for name, mdl, sign in outlier_models:
        test_scores = sign * mdl.score_samples(test_dataset.features)
        baseline_results[name] = _selective_curves(test_scores, test_conf, test_dataset, args.low_quantile)
        

def laplacian_kernel(values_i, values_j, width=0.2):
    """Laplacian kernel k(v_i, v_j) = exp(-|v_i - v_j| / width), elementwise.

    Args:
        values_i: tensor of "i" values.
        values_j: tensor of "j" values, broadcast-compatible with ``values_i``.
        width: kernel bandwidth.

    Returns:
        Tensor of kernel values with the broadcast shape of the inputs.
    """
    scaled_distance = (values_i - values_j).abs() / width
    return (-scaled_distance).exp()


def compute_smmce_loss(outputs, targets, weights, kernel_fn=None, pnorm=2):
    r"""Compute the non-normalized S-MMCE_u loss.

    .. math::
        \Big(\sum_{ij} |y_i - r_i|^p \, |y_j - r_j|^p \, g(x_i) \, g(x_j) \, k(r_i, r_j)\Big)^{1/p}

    Args:
        outputs: 1-D tensor of confidence values r_i.
        targets: 1-D tensor of target values y_i (cast to float).
        weights: 1-D tensor of selection weights g(x_i).
        kernel_fn: callable evaluated on two n x n matrices (row-repeated and
            column-repeated ``outputs``) returning k(r_i, r_j); defaults to
            ``laplacian_kernel``.
        pnorm: the p in the formula above.

    Returns:
        Scalar loss tensor.
    """
    # |y_i - r_i|^p for every example.
    calibration_error = torch.abs(targets.float() - outputs).pow(pnorm)

    # |y_i - r_i|^p * |y_j - r_j|^p
    pairwise_errors = torch.outer(calibration_error, calibration_error)

    # k(r_i, r_j), evaluated on explicit n x n matrices so a custom kernel only
    # needs elementwise operations.
    kernel_fn = kernel_fn if kernel_fn is not None else laplacian_kernel
    n = len(outputs)
    outputs_i = outputs.view(-1, 1).repeat(1, n)
    outputs_j = outputs.view(1, -1).repeat(n, 1)
    kernel_matrix = kernel_fn(outputs_i, outputs_j)

    # g(x_i) * g(x_j)
    weight_matrix = torch.outer(weights, weights)

    # Full matrix: error_i * error_j * g_i * g_j * k(i, j).
    matrix_values = pairwise_errors * kernel_matrix * weight_matrix

    # Deliberately no normalisation by the number of examples (Eq. 15).
    return matrix_values.sum().pow(1 / pnorm)


class ModelWithTemperature(nn.Module):
    """Temperature-scaling calibrator that operates directly on logits.

    Holds one learnable scalar ``temperature`` (initialised to 1.5) and divides
    incoming logits by it. It does not wrap a network: callers pass
    pre-computed classification logits (NOT softmax / log-softmax outputs).
    """

    def __init__(self):
        super().__init__()
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, logits):
        return self.temperature_scale(logits)

    def temperature_scale(self, logits):
        """Return ``logits / temperature``.

        The shape-(1,) temperature broadcasts against logits of any shape, so
        this is not limited to 2-D (batch, classes) inputs.
        """
        return logits / self.temperature

    # This function probably should live outside of this class, but whatever
    def set_temperature(self, val_dataset, device="cuda"):
        """Tune the temperature on a validation set by minimising NLL with LBFGS.

        Prints NLL and ECE before and after tuning.

        Args:
            val_dataset: object exposing ``logits`` (presumably
                (n_samples, n_classes) - TODO confirm) and integer ``labels``.
            device: device on which the optimisation runs (default ``"cuda"``).

        Returns:
            self, moved to ``device``.
        """
        self.to(device)
        nll_criterion = nn.CrossEntropyLoss()
        ece_criterion = _ECELoss()

        # First: collect all the logits and labels for the validation set
        logits = torch.as_tensor(val_dataset.logits, dtype=torch.float32, device=device)
        labels = torch.as_tensor(val_dataset.labels, dtype=torch.long, device=device)

        # Calculate NLL and ECE before temperature scaling (diagnostics only).
        with torch.no_grad():
            before_temperature_nll = nll_criterion(logits, labels).item()
            before_temperature_ece = ece_criterion(logits, labels).item()
        print('Before temperature - NLL: %.3f, ECE: %.3f' % (before_temperature_nll, before_temperature_ece))

        # Next: optimize the temperature w.r.t. NLL
        optimizer = optim.LBFGS([self.temperature], lr=0.01, max_iter=200)

        def closure():
            optimizer.zero_grad()
            loss = nll_criterion(self.temperature_scale(logits), labels)
            loss.backward()
            return loss

        optimizer.step(closure)

        # Calculate NLL and ECE after temperature scaling
        with torch.no_grad():
            scaled = self.temperature_scale(logits)
            after_temperature_nll = nll_criterion(scaled, labels).item()
            after_temperature_ece = ece_criterion(scaled, labels).item()
        print('Optimal temperature: %.3f' % self.temperature.item())
        print('After temperature - NLL: %.3f, ECE: %.3f' % (after_temperature_nll, after_temperature_ece))

        return self


class _ECELoss(nn.Module):
    """
    Calculates the Expected Calibration Error of a model.
    (This isn't necessary for temperature scaling, just a cool metric).
    The input to this loss is the logits of a model, NOT the softmax scores.
    This divides the confidence outputs into equally-sized interval bins.
    In each bin, we compute the confidence gap:
    bin_gap = | avg_confidence_in_bin - accuracy_in_bin |
    We then return a weighted average of the gaps, based on the number
    of samples in each bin
    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(_ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Calculated |confidence - accuracy| in each bin
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece



def train_h(dataset):
    """Fit and return a top-label Platt calibrator on ``dataset``.

    ``dataset`` must expose ``probs`` (class probabilities with classes on the
    last axis) and ``labels``. Every row of ``probs`` is used for calibration.
    """
    n_calibration = dataset.probs.shape[0]
    calibrator = PlattTopCalibrator(n_calibration, num_bins=10)
    calibrator.train_calibration(dataset.probs, dataset.labels)
    return calibrator


def calc_score(scoring_metric, probs, labels):
    """Compute a calibration score.

    As called in this file, ``probs`` are top-label confidences and ``labels``
    are the matching 0/1 correctness indicators.

    Args:
        scoring_metric: ``"ECE-1"`` (p=1, no debiasing), ``"ECE-2"`` (p=2,
            debiased) - both via ``cal.lower_bound_scaling_ce`` with 15 bins in
            top-label mode - or ``"Brier"``.
        probs: confidences.
        labels: targets matching ``probs``.

    Returns:
        The score returned by the underlying metric function.

    Raises:
        ValueError: if ``scoring_metric`` is not recognised. (Previously the
            ValueError class was *returned*, silently leaking into results.)
    """
    if scoring_metric == "ECE-1":
        return cal.lower_bound_scaling_ce(probs, labels, p=1, debias=False, num_bins=15,
                                          mode="top-label")
    if scoring_metric == "ECE-2":
        return cal.lower_bound_scaling_ce(probs, labels, p=2, debias=True, num_bins=15,
                                          mode="top-label")
    if scoring_metric == "Brier":
        # NOTE(review): brier_multi is neither defined nor imported in the
        # visible part of this module - confirm where it comes from.
        return brier_multi(probs, labels)
    raise ValueError("Unknown scoring metric: {!r}".format(scoring_metric))
    
    
class PlattTopCalibrator:
    """Top-label Platt scaling.

    A Platt scaler is fitted on the maximum class probability versus whether
    the arg-max prediction is correct. Calibration then maps each row's maximum
    probability through that scaler.

    Args:
        num_calibration: minimum number of examples ``train_calibration`` accepts.
        num_bins: stored as ``_num_bins``; not used by this class.
    """

    def __init__(self, num_calibration, num_bins):
        self._num_calibration = num_calibration
        self._num_bins = num_bins

    def train_calibration(self, probs, labels):
        """Fit the Platt scaler on top-label confidence vs. correctness.

        Sets ``self._platt`` (the calibrator) and ``self.clf`` (the fitted
        logistic regression).

        Raises:
            ValueError: if fewer than ``num_calibration`` examples are given.
        """
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if len(probs) < self._num_calibration:
            raise ValueError(
                "Need at least {} calibration examples, got {}".format(
                    self._num_calibration, len(probs)
                )
            )
        predictions = np.argmax(probs, -1)
        top_probs = np.max(probs, -1)
        correct = (predictions == labels)
        self._platt, self.clf = get_platt_scaler(top_probs, correct, True)

    def calibrate(self, probs):
        """Return the calibrated confidence of each row's maximum probability."""
        return self._platt(np.max(probs, -1))
    

def get_platt_scaler(model_probs, labels, get_clf=False):
    """Fit Platt scaling: a logistic regression on the log-odds of ``model_probs``.

    Probabilities are clipped to ``[1e-12, 1 - 1e-12]`` before the logit
    transform. The regression uses ``C=1e10``, i.e. essentially no
    regularisation.

    Args:
        model_probs: 1-D array of probabilities.
        labels: binary targets aligned with ``model_probs``.
        get_clf: when True, also return the fitted ``LogisticRegression``.

    Returns:
        ``calibrator`` (a function mapping probabilities to calibrated
        probabilities), or ``(calibrator, clf)`` when ``get_clf`` is True.
    """
    eps = 1e-12

    def _to_logit(p):
        p = np.clip(p, eps, 1 - eps)
        return np.log(p / (1 - p))

    clf = LogisticRegression(C=1e10, solver='lbfgs')
    features = _to_logit(np.expand_dims(model_probs.astype(dtype=np.float64), axis=-1))
    clf.fit(features, labels)

    def calibrator(probs):
        z = _to_logit(np.array(probs, dtype=np.float64))
        z = z * clf.coef_[0] + clf.intercept_
        return 1 / (1 + np.exp(-z))

    return (calibrator, clf) if get_clf else calibrator